"""
Structural Gravity Feasibility Assessment.

Tests whether standard structural gravity estimation (origin×year + destination×year FE)
is feasible for identifying bilateral demographic distance effects.

Key concern: ΔZ_k = f(Z_k_i) - f(Z_k_j) is the difference of two country×year
variables, one for each side of the pair. Origin×year + destination×year FE absorb
all origin×year and all destination×year variation, so they should mechanically
absorb ΔZ_k.

This script:
1. Verifies the absorption claim by regressing ΔZ_k (and ΔZ_k × KAOPEN_j) on
   origin×year + destination×year dummies
2. Tests partial structural gravity: reporter×year FE only
3. Tests three-way additive FE: reporter + partner + year (not interacted)
4. Checks whether KAOPEN interactions survive these specifications
5. Re-estimates pooled OLS with year dummies on the same sample for comparison
"""

import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats

# Project paths. BASE_DIR is an absolute, machine-specific location; adjust it
# when running the script on another machine.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral_followup")
# Input directory: load_panel() reads bilateral_panel.csv from here.
PROCESSED_DIR = BASE_DIR / "data" / "processed"
# Output directory: main() writes structural_gravity_assessment.md here.
OUTPUT_DIR = BASE_DIR / "output" / "tables"

# Put the project root on the import path so the `src` package resolves; this
# must run before the import below.
sys.path.insert(0, str(BASE_DIR))
# NOTE(review): PanelGLS is not referenced in the visible part of this file
# (Test 4 uses statsmodels OLS instead) — confirm whether it is still needed.
from src.model import PanelGLS


def load_panel():
    """Read the bilateral panel CSV and print a short coverage summary.

    Returns the panel as a DataFrame, unmodified.
    """
    panel = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")
    n_rows = len(panel)
    n_pairs = panel['pair_id'].nunique()
    print(f"Loaded: {n_rows:,} obs, {n_pairs:,} pairs")
    year_list = sorted(panel['year'].unique())
    print(f"Years: {year_list}")
    for label, column in (("Origins (iso_o)", 'iso_o'),
                          ("Destinations (iso_d)", 'iso_d')):
        print(f"{label}: {panel[column].nunique()}")
    return panel


def demean(y, groups):
    """Subtract group means from ``y``.

    Parameters
    ----------
    y : array-like of numbers
        Values to demean. The result is ``y - <ndarray of group means>``, so it
        is an ndarray for ndarray input and a Series (keeping ``y``'s index)
        for Series input.
    groups : array-like
        Group label of each element of ``y``, matched by position.

    Returns
    -------
    ``y`` minus the mean of its group. Rows whose group label is missing are
    presumably NaN (pandas ``groupby`` drops NaN keys by default) — confirm for
    the installed pandas version.
    """
    # Build the frame from plain arrays so y and groups line up by position;
    # two pandas objects with different indexes would otherwise be aligned on
    # their index labels, misplacing (or mis-sizing) the group means.
    df_tmp = pd.DataFrame({'y': np.asarray(y), 'g': np.asarray(groups)})
    means = df_tmp.groupby('g')['y'].transform('mean')
    return y - means.values


def iterative_demean(y, group1, group2, max_iter=500, tol=1e-8):
    """
    Alternating projection to remove two-way fixed effects.

    Demean by ``group1``, then by ``group2``, and repeat until the RMS change
    over one full sweep falls below ``tol`` (an absolute tolerance in the
    units of ``y``) or ``max_iter`` sweeps have run.

    Returns the float residual of ``y`` with both sets of FE projected out.
    If the iteration stops without converging, a RuntimeWarning is issued: an
    unconverged residual still contains FE components, so any R-squared
    computed from it is biased downward.
    """
    resid = y.copy().astype(float)
    if len(resid) == 0:
        # Nothing to project; avoids mean-of-empty warnings and futile sweeps.
        return resid
    change = np.inf
    for _ in range(max_iter):
        old = resid.copy()
        resid = demean(resid, group1)
        resid = demean(resid, group2)
        change = np.sqrt(np.mean((resid - old) ** 2))
        if change < tol:
            break
    else:
        warnings.warn(
            f"iterative_demean did not converge in {max_iter} iterations "
            f"(last RMS change {change:.3g}, tol {tol:.3g})",
            RuntimeWarning,
            stacklevel=2,
        )
    return resid


def _absorb_fe(v, dummy_groups, max_iter=500, tol=1e-8):
    """Return ``v`` (as float) with the fixed effects in ``dummy_groups`` removed.

    One group is removed exactly by a single demeaning pass. Two or more
    groups are removed by alternating projections: demean by each group in
    turn and repeat until the RMS change over a full sweep falls below ``tol``
    (absolute, in the units of ``v``). A RuntimeWarning is issued if
    ``max_iter`` sweeps pass without convergence, because the partially
    projected result still contains FE components.
    """
    v = np.asarray(v, dtype=float).copy()
    if len(dummy_groups) == 1:
        return demean(v, dummy_groups[0])
    if v.size == 0:
        return v
    change = np.inf
    for _ in range(max_iter):
        old = v.copy()
        for g in dummy_groups:
            v = demean(v, g)
        change = np.sqrt(np.mean((v - old) ** 2))
        if change < tol:
            break
    else:
        warnings.warn(
            f"FE absorption did not converge in {max_iter} sweeps "
            f"(last RMS change {change:.3g}, tol {tol:.3g})",
            RuntimeWarning,
            stacklevel=3,
        )
    return v


def _n_absorbed_params(dummy_groups):
    """Approximate number of parameters absorbed by the fixed effects.

    Counts the distinct levels of every group and subtracts one normalisation
    per additional group. This is exact for a single group, and for groups
    that form one connected set without further redundancies (e.g. reporter,
    partner and year FE in a well-connected panel). Otherwise it over-states
    the absorbed parameters, which only makes the standard errors somewhat
    conservative.
    """
    levels = [len(pd.unique(np.asarray(g))) for g in dummy_groups]
    return sum(levels) - (len(levels) - 1)


def ols_with_dummies(y, X, dummy_groups, dummy_names=None):
    """
    OLS with high-dimensional fixed effects via Frisch-Waugh-Lovell.

    The fixed effects given by each label array in ``dummy_groups`` are
    projected out of ``y`` and of every column of ``X``. One group uses a
    single demeaning pass; two or more groups use alternating projections.
    OLS is then run on the demeaned data without a constant.

    Parameters
    ----------
    y : array-like, shape (n,)
        Dependent variable (must not contain missing values).
    X : array-like, shape (n, k) or (n,)
        Regressors. A 1-D array is treated as a single column; ``k`` may be 0.
    dummy_groups : sequence of array-like
        One or more FE label arrays, each of length n.
    dummy_names : ignored
        Unused; kept for backward compatibility.

    Returns
    -------
    (beta, se, p_vals, r2_full, n)
        ``r2_full`` is the R-squared of the full model (FE + regressors)
        relative to the grand mean of ``y``. With no regressors, ``beta``,
        ``se`` and ``p_vals`` are None and ``r2_full`` is the FE-only
        R-squared. If the least-squares solve fails, every entry except
        ``n`` is None.

    Standard errors are homoskedastic (not clustered). Their residual degrees
    of freedom subtract the absorbed FE parameters (see ``_n_absorbed_params``).

    Raises
    ------
    ValueError
        If ``dummy_groups`` is empty.
    """
    if len(dummy_groups) == 0:
        raise ValueError("At least one fixed-effect group is required")

    y = np.asarray(y, dtype=float)
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X.reshape(-1, 1)

    y_dm = _absorb_fe(y, dummy_groups)
    # np.empty_like also covers k == 0, where column_stack([]) would raise.
    X_dm = np.empty_like(X)
    for j in range(X.shape[1]):
        X_dm[:, j] = _absorb_fe(X[:, j], dummy_groups)

    # Total sum of squares around the grand mean: the FE explain everything
    # that the demeaned y no longer contains.
    ss_tot = np.sum((y - np.mean(y)) ** 2)

    if X_dm.shape[1] == 0:
        # No regressors beyond FE: the demeaned y is the FE-only residual.
        r2_fe = 1 - np.sum(y_dm ** 2) / ss_tot
        return None, None, None, r2_fe, len(y)

    try:
        beta = np.linalg.lstsq(X_dm, y_dm, rcond=None)[0]
    except np.linalg.LinAlgError:
        return None, None, None, None, len(y)

    # By FWL, the residual of the demeaned regression IS the full-model residual.
    resid_dm = y_dm - X_dm @ beta
    ss_res = np.sum(resid_dm ** 2)
    r2_full = 1 - ss_res / ss_tot

    n, k = X_dm.shape
    # The within transformation estimates the FE implicitly. Their parameters
    # must be charged to the residual degrees of freedom; n - k alone
    # understates the standard errors.
    dof = max(n - k - _n_absorbed_params(dummy_groups), 1)
    sigma2 = ss_res / dof
    try:
        XtX_inv = np.linalg.inv(X_dm.T @ X_dm)
        se = np.sqrt(np.diag(sigma2 * XtX_inv))
    except np.linalg.LinAlgError:
        se = np.full(k, np.nan)

    t_vals = beta / se
    # sf avoids the cancellation in 1 - cdf for very small p-values.
    p_vals = 2 * stats.t.sf(np.abs(t_vals), dof)

    return beta, se, p_vals, r2_full, n


def _sig_stars(p):
    """Return conventional significance stars for p-value ``p`` ('' for NaN)."""
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


def _fmt_stat(value, spec='.4f'):
    """Format a statistic, rendering ``None`` (failed estimation) as 'n/a'."""
    return 'n/a' if value is None else format(value, spec)


def _log_coef_table(log, names, beta, se, pvals):
    """Log a markdown coefficient table for ``names``, followed by a blank line.

    Rows are omitted when ``beta`` is None (the estimation failed).
    """
    log("| Variable | Coefficient | Std Error | p-value | Sig |")
    log("|----------|-------------|-----------|---------|-----|")
    if beta is not None:
        for i, name in enumerate(names):
            log(f"| {name} | {beta[i]:.4f} | {se[i]:.4f} | {pvals[i]:.4f} | {_sig_stars(pvals[i])} |")
    log("")


def _absorption_r2(values, groups_a, groups_b):
    """Return the R-squared of ``values`` on two FE sets (NaN if constant)."""
    resid = iterative_demean(values, groups_a, groups_b)
    ss_tot = np.sum((values - np.mean(values)) ** 2)
    if ss_tot == 0:
        return float('nan')
    return 1 - np.sum(resid ** 2) / ss_tot


def _absorption_label(r2):
    """Map an absorption R-squared to the report's interpretation text."""
    if np.isnan(r2):
        return "No variation"
    if r2 > 0.999:
        return "**Fully absorbed**"
    if r2 > 0.9:
        return "Partially absorbed"
    return "Residual variation exists"


def _log_absorption_table(log, est, columns):
    """Log the o×t + d×t absorption R-squared of each column; return {col: R²}."""
    log("| Variable | R-squared (o×t + d×t FE) | Interpretation |")
    log("|----------|--------------------------|----------------|")
    r2_by_col = {}
    for col in columns:
        r2 = _absorption_r2(est[col].values, est['ot'].values, est['dt'].values)
        r2_by_col[col] = r2
        log(f"| {col} | {r2:.6f} | {_absorption_label(r2)} |")
    log("")
    return r2_by_col


def _log_fe_model(log, est, dep_var, regressors, fe_cols):
    """Fit ``dep_var`` on ``regressors`` absorbing FE ``fe_cols``; log the table.

    Returns ``(beta, pvals)``; both are None if the estimation failed.
    """
    sub = est.dropna(subset=[dep_var] + regressors)
    beta, se, pvals, r2, n = ols_with_dummies(
        sub[dep_var].values,
        sub[regressors].values,
        [sub[col].values for col in fe_cols],
    )
    log(f"N = {n:,}, R-squared = {_fmt_stat(r2)}")
    log("")
    _log_coef_table(log, regressors, beta, se, pvals)
    return beta, pvals


def _log_pooled_ols(log, est, dep_var, regressors, years, label):
    """Pooled OLS of ``dep_var`` on ``regressors`` + constant + year dummies.

    Dummies are built for every year in ``years`` except the first (the
    omitted base year). Logs the fit and returns ``(params, pvalues)`` with the
    constant dropped; the regressors come first, in the order given.
    """
    sub = est.dropna(subset=[dep_var] + regressors).copy()
    yr_cols = [f'yr_{yr}' for yr in years[1:]]
    for yr, col in zip(years[1:], yr_cols):
        sub[col] = (sub['year'] == yr).astype(int)
    X = sm.add_constant(sub[regressors + yr_cols].values)
    fit = sm.OLS(sub[dep_var].values, X).fit()
    log(f"{label}: N = {len(sub):,}, R-squared = {fit.rsquared:.4f}")
    log("")
    # Index 0 is the constant; the listed regressors follow it.
    _log_coef_table(log, regressors, fit.params[1:], fit.bse[1:], fit.pvalues[1:])
    return fit.params[1:], fit.pvalues[1:]


def _log_key_findings(log, absorption, fits):
    """Log factual summary bullets: absorption R-squared values and which dZ
    terms are significant at 5% in each specification."""
    for label, r2_by_col in absorption:
        summary = ", ".join(f"{col} = {r2:.6f}" for col, r2 in r2_by_col.items())
        log(f"- {label}: {summary}")
    for label, names, beta, pvals in fits:
        if beta is None:
            log(f"- {label}: estimation failed")
            continue
        sig = [name for name, p in zip(names, pvals)
               if name.startswith('dZ_') and p < 0.05]
        log(f"- {label}: dZ terms significant at 5%: {', '.join(sig) if sig else 'none'}")
    log("")


def main():
    """Run the structural-gravity feasibility tests and write a markdown report.

    Every report line is printed and collected. The report is written as
    UTF-8 to ``OUTPUT_DIR / "structural_gravity_assessment.md"``; the
    directory is created if it does not exist. Returns the list of lines.
    """
    df = load_panel()
    results_lines = []

    def log(msg):
        print(msg)
        results_lines.append(msg)

    def log_lines(*lines):
        for line in lines:
            log(line)

    log_lines(
        "# Structural Gravity Feasibility Assessment",
        "",
        "## Background",
        "",
        "The structural gravity literature (Head & Mayer 2014, Anderson & van Wincoop 2003)",
        "uses origin x year and destination x year fixed effects to control for multilateral",
        "resistance terms. This absorbs all country-level time-varying unobservables.",
        "",
        "Our key variable, bilateral demographic distance dZ_k = f(Z_k_i) - f(Z_k_j),",
        "is constructed from country-level demographic polynomials. The concern is that",
        "origin x year + destination x year FE mechanically absorb dZ_k since it is a",
        "linear combination of origin-level and destination-level variables.",
        "",
    )

    # ===================================================================
    # 1. ABSORPTION TEST: How much of dZ variation do country×year FE absorb?
    # ===================================================================
    log_lines("## Test 1: Absorption of dZ by Country x Year Fixed Effects", "")

    # Estimation sample: rows with every listed column non-missing.
    dep_var = 'log_portfolio_total'
    est_cols = ['dZ_1', 'dZ_2', 'dZ_3', dep_var, 'iso_o', 'iso_d', 'year',
                'log_dist', 'contiguity', 'common_lang_official', 'colonial_ties',
                'log_gdp_product', 'pair_id', 'kaopen_j',
                'dZ_1_x_kaopen_j', 'dZ_2_x_kaopen_j', 'dZ_3_x_kaopen_j',
                'Z_1_i', 'Z_2_i', 'Z_3_i', 'Z_1_j', 'Z_2_j', 'Z_3_j']
    # Columns absent from the panel are skipped here (later lookups would fail).
    est = df.dropna(subset=[c for c in est_cols if c in df.columns]).copy()

    log(f"Estimation sample: {len(est):,} obs, {est['iso_o'].nunique()} origins, "
        f"{est['iso_d'].nunique()} destinations, {est['year'].nunique()} years")
    log("")

    # Country×year group labels used as fixed effects.
    est['ot'] = est['iso_o'].astype(str) + '_' + est['year'].astype(str)
    est['dt'] = est['iso_d'].astype(str) + '_' + est['year'].astype(str)

    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    interaction_vars = ['dZ_1_x_kaopen_j', 'dZ_2_x_kaopen_j', 'dZ_3_x_kaopen_j']

    log_lines("### Regress each dZ_k on origin x year + destination x year FE", "")
    dz_r2 = _log_absorption_table(log, est, demo_vars)

    log_lines("### Regress dZ_k x KAOPEN_j on origin x year + destination x year FE", "")
    inter_r2 = _log_absorption_table(log, est, interaction_vars)

    log_lines(
        "### Analytical Note",
        "",
        "Since dZ_k = Z_k(reporter) - Z_k(partner), and Z_k is purely country×year level,",
        "origin×year FE absorb Z_k(reporter) exactly, and destination×year FE absorb",
        "Z_k(partner) exactly. Therefore dZ_k is *perfectly* collinear with the two-way",
        "country×year FE. R-squared should be exactly 1.0 (up to numerical precision).",
        "",
        "For dZ_k × KAOPEN_j: KAOPEN_j is destination×year level, so dZ_k × KAOPEN_j",
        "= Z_k_i × KAOPEN_j - Z_k_j × KAOPEN_j. The second term is absorbed by d×t FE,",
        "but Z_k_i × KAOPEN_j is NOT absorbed by either o×t or d×t FE alone (it is a",
        "cross-product of origin-level and destination-level variables). So the interaction",
        "may have residual variation.",
        "",
    )

    # ===================================================================
    # 2. PARTIAL STRUCTURAL GRAVITY: Reporter × Year FE only
    # ===================================================================
    log_lines(
        "## Test 2: Partial Structural Gravity (Reporter x Year FE Only)",
        "",
        "This absorbs all reporter-level variation (including Z_k_i) but preserves",
        "partner-level variation. The identifying variation comes from partner demographics",
        "Z_k_j (and partner KAOPEN_j).",
        "",
        "Under o×t FE, dZ_k = [Z_k_i absorbed by FE] - Z_k_j. So only the partner",
        "component survives. The coefficient on dZ_k estimates the effect of partner",
        "demographics on bilateral flows, controlling for all reporter×year confounds.",
        "",
    )

    # Pair-level gravity controls survive every FE set used here.
    # log_gdp_product is skipped in FE models: its reporter-GDP component is
    # absorbed by reporter×year FE.
    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties']
    kaopen_vars = ['kaopen_j'] + interaction_vars
    model_b = gravity_vars + demo_vars
    model_c = gravity_vars + demo_vars + kaopen_vars
    three_way_fe = ['iso_o', 'iso_d', 'year']

    fits = []

    log_lines("### Model 2b with Reporter x Year FE", "")
    beta, pvals = _log_fe_model(log, est, dep_var, model_b, ['ot'])
    fits.append(("Reporter x Year FE, Model 2b", model_b, beta, pvals))

    log_lines("### Model 2c with Reporter x Year FE (+ KAOPEN interactions)", "")
    beta, pvals = _log_fe_model(log, est, dep_var, model_c, ['ot'])
    fits.append(("Reporter x Year FE, Model 2c", model_c, beta, pvals))

    # ===================================================================
    # 3. THREE-WAY ADDITIVE FE: Reporter + Partner + Year
    # ===================================================================
    log_lines(
        "## Test 3: Three-Way Additive FE (Reporter + Partner + Year)",
        "",
        "Uses reporter FE + partner FE + year FE (not interacted). This absorbs",
        "time-invariant country characteristics and common time trends, but preserves",
        "country-specific time variation in demographics.",
        "",
    )

    log_lines("### Model 2b with Reporter + Partner + Year FE", "")
    beta, pvals = _log_fe_model(log, est, dep_var, model_b, three_way_fe)
    fits.append(("Reporter + Partner + Year FE, Model 2b", model_b, beta, pvals))

    log_lines("### Model 2c with Reporter + Partner + Year FE (+ KAOPEN interactions)", "")
    beta, pvals = _log_fe_model(log, est, dep_var, model_c, three_way_fe)
    fits.append(("Reporter + Partner + Year FE, Model 2c", model_c, beta, pvals))

    # ===================================================================
    # 4. COMPARISON WITH BASELINE (pooled; estimated here by OLS, not PanelGLS)
    # ===================================================================
    log_lines(
        "## Test 4: Comparison with Baseline Pooled GLS",
        "",
        "For reference, re-estimate the pooled specification (no FE beyond year dummies)",
        "on the same sample, to compare coefficient magnitudes and significance.",
        "",
    )

    pool_b = gravity_vars + demo_vars + ['log_gdp_product']
    pool_c = pool_b + kaopen_vars
    # The year set comes from the Model 2b sample and is reused for 2c.
    years = sorted(est.dropna(subset=[dep_var] + pool_b)['year'].unique())

    params, pvals = _log_pooled_ols(log, est, dep_var, pool_b, years, "Pooled OLS")
    fits.append(("Pooled OLS + year dummies", pool_b, params, pvals))

    params, pvals = _log_pooled_ols(log, est, dep_var, pool_c, years, "Pooled OLS + KAOPEN")
    fits.append(("Pooled OLS + year dummies + KAOPEN", pool_c, params, pvals))

    # ===================================================================
    # 5. SUMMARY
    # ===================================================================
    log_lines("## Summary and Recommendations", "", "### Key Findings", "")
    _log_key_findings(
        log,
        [("Absorption R-squared of dZ_k (o×t + d×t FE)", dz_r2),
         ("Absorption R-squared of dZ_k x KAOPEN_j (o×t + d×t FE)", inter_r2)],
        fits,
    )

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    outfile = OUTPUT_DIR / "structural_gravity_assessment.md"
    # Explicit UTF-8: the report contains non-ASCII characters such as "×".
    outfile.write_text('\n'.join(results_lines), encoding='utf-8')

    print(f"\nSaved: {outfile}")
    return results_lines


# Script entry point: run the full assessment when executed directly; the
# returned report lines are not used here.
if __name__ == "__main__":
    main()
